import numpy as np
import PIL.Image
import io
import base64
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
import time


def load_data_lists(split, save_dir):
    """Load one split's arrays from ``{save_dir}/{split}_data.npz``.

    Args:
        split: Split name used to build the archive file name. For
            ``"train"`` the stored ``images`` array is returned untouched;
            for any other value every image is wrapped with
            ``PIL.Image.fromarray``.
        save_dir: Directory that contains the ``.npz`` archive.

    Returns:
        A 5-tuple ``(fmri, images, captions, coco_ids, responses)``.
        ``fmri`` is the stored array; ``captions``, ``coco_ids`` and
        ``responses`` are converted to Python lists (``captions`` becomes
        ``[]`` when the stored array is empty).

    Raises:
        FileNotFoundError: If the archive does not exist.
        KeyError: If one of the expected arrays is missing from the archive.
    """
    # NOTE(review): allow_pickle=True unpickles object arrays, which can run
    # arbitrary code -- only load archives produced by a trusted pipeline.
    # The context manager closes the underlying zip file once every array
    # has been read into memory.
    with np.load(f'{save_dir}/{split}_data.npz', allow_pickle=True) as data:
        fmri = data['fmri']
        images = data['images']
        captions = data['captions']  # read once; each access re-decodes the array
        coco_ids = data['coco_ids'].tolist()
        responses = data['responses'].tolist()

    if split != "train":
        images = [PIL.Image.fromarray(img) for img in images]

    return (
        fmri,
        images,
        captions.tolist() if len(captions) > 0 else [],
        coco_ids,
        responses,
    )



def convert_pil_to_base64(pil_image):
    """Encode an image as a base64 string of its JPEG bytes.

    Args:
        pil_image: A ``PIL.Image.Image``. A NumPy array is also accepted and
            is first turned into an image with ``PIL.Image.fromarray``.

    Returns:
        The JPEG-encoded image as a UTF-8 base64 string (no data-URL prefix).
    """
    # Raw arrays (e.g. images read straight from an .npz archive) have no
    # ``save`` method, so wrap them first.
    if isinstance(pil_image, np.ndarray):
        pil_image = PIL.Image.fromarray(pil_image)

    # Pillow's JPEG writer only handles these modes; anything else (RGBA, LA,
    # P, ...) raises "cannot write mode X as JPEG", so fall back to RGB.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if pil_image.mode not in jpeg_modes:
        pil_image = pil_image.convert("RGB")

    buffer = io.BytesIO()
    pil_image.save(buffer, format='JPEG')  # Saves to memory buffer, not disk
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def gpt_response(img, caption, keyword, max_attempts=10, retry_delay=5):
    """Ask ``gpt-4o-mini`` to describe the two main objects of an image.

    The image is sent inline as a base64 JPEG data URL together with a
    few-shot style prompt. When ``caption`` is truthy the prompt also quotes
    the caption and asks for a ``[left|right|top|bottom]`` position tag per
    object; otherwise a caption-free variant of the prompt is used.

    Args:
        img: Image accepted by ``convert_pil_to_base64``.
        caption: Caption text for the image, or a falsy value when none is
            available.
        keyword: Aspect the model is told to describe the objects with.
        max_attempts: Maximum number of model invocations before giving up.
        retry_delay: Seconds to sleep between consecutive attempts.

    Returns:
        The non-empty text content of the first successful model response.

    Raises:
        Exception: If every attempt failed or returned empty content. The
            most recent underlying error, if any, is attached as the cause.
    """
    img = convert_pil_to_base64(img)
    # NOTE: the prompt text below is kept verbatim (including the "styl"
    # spelling in the example) because it defines the format of the model
    # output -- presumably something downstream parses it; verify before
    # editing.
    if caption:
        prompt = f'''
        Given the image and caption, first give me the background color style of the image with 3-5 words. Secondly, detect the TWO most important objects in the image. Then, describe each of the objects using the keyword: {keyword} as follow with TWO sentences. For each sentence, use 5-10 words and as easy as possible.
        Then, detect the absolute position of the two objects in the image, select from [right, left, top, bottom]. "left" and "right" should appear together for horizontal objects, and "top" and "bottom" should appear together for vertical objects. DO NOT mix.
        ### Background color styl: Grayscale urban.
        
        ### The Man [left]
        1. The man is standing near the sidewalk edge. The Man is close to the building wall.
    
        ### The Suitcase [right]
        1. The suitcase is beside the man's foot. The Suitcase is placed on the street's curved edge.
    
        Now, given the image I uploaded and the caption "{caption}", describe the two most important objects using the keyword {keyword} with EXACTLY the example format:
        '''
    else:
        prompt = f'''
                Given the image, first give me the background color style of the image with 3-5 words. Secondly, detect the TWO most important objects in the image. Then, describe each of the objects using the keyword: {keyword} as follow with TWO sentences. For each sentence, use 5-10 words and as easy as possible.

                ### Background color styl: Grayscale urban.

                ### The Man
                1. The man is standing near the sidewalk edge. The Man is close to the building wall.


                ### The Suitcase
                1. The suitcase is beside the man's foot. The Suitcase is placed on the street's curved edge.

                Now, given the image I uploaded, describe the two most important objects using the keyword {keyword} with EXACTLY the example format:
                '''

    # NOTE(review): the API key is hard-coded here (placeholder value);
    # consider supplying it through configuration or the environment.
    model = ChatOpenAI(model="gpt-4o-mini",
                       openai_api_key="xxx",
                       temperature=0,
                       max_tokens=None,
                       timeout=None,
                       max_retries=2)

    message = HumanMessage(
        content=[
            {"type": "text", "text": prompt},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{img}"},
            },
        ],
    )

    attempts = 0
    last_error = None  # most recent exception, chained into the final failure

    while attempts < max_attempts:
        attempts += 1
        try:
            response = model.invoke([message])
            if response and response.content and len(response.content.strip()) > 0:
                return response.content
            else:
                print(f"Empty response received on attempt {attempts}. Retrying...")
        except Exception as e:
            # Deliberately broad: any failure of a single attempt is retried.
            last_error = e
            print(f"Error on attempt {attempts}: {str(e)}")

        # No point sleeping after the final attempt.
        if attempts < max_attempts:
            print(f"Waiting {retry_delay} seconds before retry...")
            time.sleep(retry_delay)

    # Chain the last underlying error so the real cause is not lost.
    raise Exception(f"Failed to get response after {max_attempts} attempts") from last_error


def process_item(cur_caption, cur_img, keyword):
    """Return the model description for one (caption, image) pair.

    Thin adapter around ``gpt_response`` that takes the caption first and
    the image second, which is the argument order used when jobs are
    submitted to the thread pool.
    """
    return gpt_response(img=cur_img, caption=cur_caption, keyword=keyword)


import concurrent.futures
def get_data_responses(caption_list, img_list, keyword, max_workers=10):
    """Query the model for every image concurrently and collect the replies.

    Args:
        caption_list: Captions aligned index-by-index with ``img_list``. An
            empty sequence means no captions are available (e.g. test data),
            in which case ``None`` is passed as the caption for every image.
        img_list: Images to describe.
        keyword: Keyword forwarded to ``process_item`` for every image.
        max_workers: Size of the thread pool, i.e. the maximum number of
            requests in flight at once.

    Returns:
        A list with one response per image, in the same order as
        ``img_list``.

    Raises:
        IndexError: If ``caption_list`` is non-empty but shorter than
            ``img_list``.
        Exception: Whatever ``process_item`` raised for the first failing
            image (re-raised by ``future.result()``).
    """
    # len() rather than truthiness: caption_list may be a NumPy array, whose
    # truth value is ambiguous.
    has_captions = len(caption_list) != 0

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                process_item,
                caption_list[i] if has_captions else None,
                img,
                keyword,
            )
            for i, img in enumerate(img_list)
        ]
        # Collect in submission order so responses stay aligned with images.
        response_list = [future.result() for future in futures]
    return response_list



# ---------------------------------------------------------------------------
# Driver: generate model descriptions for the test split, then the train split,
# and store each list of responses in its own .npz file under ``save_dir``.
# ---------------------------------------------------------------------------

# Test split comes from the per-split archive via the shared loader.
split = "test"
save_dir = "data/"

data = load_data_lists(split, save_dir)
fmri_test, img_test, caption_test, coco_id_test = data[:4]
print("Data loaded! Length: ", len(fmri_test), len(img_test), len(caption_test), len(coco_id_test))

# Train split is read directly from a separate combined archive.
train_data = np.load("data.npz", allow_pickle=True)
fmri_train = train_data['train_fmri']

caption_list_train = train_data['caption_list_train']
coco_id_train = train_data['coco_id_list_train']
img_list_train = train_data['img_list_train']
print(len(img_list_train))


# Aspect the model is asked to describe ("Spatial Associations" was an
# earlier alternative).
keyword = "Spatial layout"


# --- test split ---
print("Processing: ", split)
caption_list, img_list = caption_test, img_test
response_list = get_data_responses(caption_list, img_list, keyword)
print(len(response_list))
# Persist the responses; they are stored under the key 'response_list'.
np.savez(f'{save_dir}/{split}_proposed_description.npz', response_list=response_list)

# --- train split ---
split = "train"
print("Processing: ", split)
img_list, caption_list = img_list_train, caption_list_train
response_list = get_data_responses(caption_list, img_list, keyword)
print(len(response_list))
# Persist the responses; they are stored under the key 'response_list'.
np.savez(f'{save_dir}/{split}_proposed_description.npz', response_list=response_list)


